<!-- This demo is part of the Burn project: https://github.com/tracel-ai/burn

    Released under a dual license: 
    https://github.com/tracel-ai/burn/blob/main/LICENSE-MIT

    https://github.com/tracel-ai/burn/blob/main/LICENSE-APACHE
-->
<!DOCTYPE html>
<html>
  <head>
    <meta charset="utf-8" />
    <title>Burn MNIST Inference Web Demo</title>

    <script
      src="https://cdn.jsdelivr.net/npm/fabric@5.3.0/dist/fabric.min.js"
      integrity="sha256-SPjwkVvrUS/H/htIwO6wdd0IA8eQ79/XXNAH+cPuoso="
      crossorigin="anonymous"
    ></script>

    <script
      src="https://cdn.jsdelivr.net/npm/chart.js@4.2.1/dist/chart.umd.min.js"
      integrity="sha256-tgiW1vJqfIKxE0F2uVvsXbgUlTyrhPMY/sm30hh/Sxc="
      crossorigin="anonymous"
    ></script>

    <script
      src="https://cdn.jsdelivr.net/npm/chartjs-plugin-datalabels@2.2.0/dist/chartjs-plugin-datalabels.min.js"
      integrity="sha256-IMCPPZxtLvdt9tam8RJ8ABMzn+Mq3SQiInbDmMYwjDg="
      crossorigin="anonymous"
    ></script>

    <link
      rel="stylesheet"
      href="https://cdn.jsdelivr.net/npm/normalize.min.css@8.0.1/normalize.min.css"
      integrity="sha256-oeib74n7OcB5VoyaI+aGxJKkNEdyxYjd2m3fi/3gKls="
      crossorigin="anonymous"
    />

    <style>
      h1 {
        padding: 15px;
      }
      th,
      td {
        padding: 5px;
        text-align: center;
        vertical-align: middle;
      }
    </style>
  </head>
  <body>
    <h1>Burn MNIST Inference Demo</h1>

    <!-- Three-column layout: drawing surface, preprocessed preview, and result chart. -->
    <table>
      <tr>
        <th>Draw a digit here</th>
        <th>Cropped and scaled</th>
        <th>Probability result</th>
      </tr>
      <tr>
        <td>
          <!-- 300x300 drawing surface; the script wraps it in a fabric.Canvas in drawing mode. -->
          <canvas id="main-canvas" width="300" height="300" style="border: 1px solid #aaa"></canvas>
        </td>
        <td>
          <!-- 28x28 pixel buffer, stretched to 100x100 CSS pixels for display only. -->
          <canvas
            id="scaled-canvas"
            width="28"
            height="28"
            style="border: 1px solid #aaa; width: 100px; height: 100px"
          ></canvas>
          <!-- Hidden canvas; its 2D context is handed to cropScaleGetImageData by the script
               (presumably as scratch space for the crop step; see index.js to confirm). -->
          <canvas id="crop-canvas" width="28" height="28" style="display: none"></canvas>
        </td>
        <td>
          <!-- Target canvas for the chart built by chartConfigBuilder. -->
          <canvas id="chart" style="border: 1px solid #aaa; width: 600px; height: 300px"></canvas>
        </td>
      </tr>
      <tr>
        <!-- The script attaches the reset handler to this button via its id. -->
        <td><button id="clear">Clear</button></td>
        <td></td>
        <td></td>
      </tr>
    </table>

    <div></div>

    <script type="module">
      import { $, cropScaleGetImageData, toFixed, chartConfigBuilder } from "./index.js";

      import { default as wasm, Mnist } from "./pkg/mnist_inference_web.js";

      // Chart object for the "chart" canvas; the rest of this script writes to
      // chart.data.datasets[0].data and calls chart.update() to refresh it.
      // `$` comes from ./index.js (presumably an element-by-id lookup; confirm there).
      const chart = chartConfigBuilder($("chart"));

      // Canvas elements declared in the markup above, plus their 2D contexts.
      const mainCanvasEl = $("main-canvas");
      const scaledCanvasEl = $("scaled-canvas");
      const cropEl = $("crop-canvas");
      // NOTE(review): the contexts are requested with willReadFrequently before fabric
      // wraps the main canvas below; presumably the order matters so the hint applies
      // to the context fabric ends up using. Confirm before reordering.
      const mainContext = mainCanvasEl.getContext("2d", { willReadFrequently: true });
      const cropContext = cropEl.getContext("2d", { willReadFrequently: true });
      const scaledContext = scaledCanvasEl.getContext("2d", { willReadFrequently: true });

      // Wrap the main canvas with Fabric.js in free-drawing mode.
      const fabricCanvas = new fabric.Canvas(mainCanvasEl, {
        isDrawingMode: true,
      });

      // Background reused by the "clear" handler to restore the canvas after a reset.
      const backgroundColor = "rgba(255, 255, 255, 255)"; // White with solid alpha
      // Brush stroke width of 25 on the 300x300 drawing canvas.
      fabricCanvas.freeDrawingBrush.width = 25;
      fabricCanvas.backgroundColor = backgroundColor;

      // Reset handler for the "Clear" button: wipes the drawing, the scaled
      // preview and the probability chart.
      $("clear").onclick = function () {
        // MNIST classifies the ten digits 0-9, so the chart needs ten values.
        const digitClassCount = 10;

        // fabricCanvas.clear() also drops the background colour, so restore it
        // before re-rendering.
        fabricCanvas.clear();
        fabricCanvas.backgroundColor = backgroundColor;
        fabricCanvas.renderAll();
        mainContext.clearRect(0, 0, mainCanvasEl.width, mainCanvasEl.height);
        scaledContext.clearRect(0, 0, scaledCanvasEl.width, scaledCanvasEl.height);

        // One zero per digit class, so every bar is explicitly reset instead of
        // the last one being left without a value.
        chart.data.datasets[0].data = new Array(digitClassCount).fill(0.0);
        chart.update();
      };

      // Handle of the pending debounce timer, kept so a newer request can cancel it.
      let timeoutId;
      // True while the pointer is pressed on the drawing canvas.
      let isDrawing = false;
      // True from the moment an inference is scheduled until it has finished or failed.
      let isTimeOutSet = false;

      // Load the WebAssembly module, then wire the drawing events to the classifier.
      wasm()
        .then(() => {
          const mnist = new Mnist();

          /**
           * Schedules an inference run 50 ms from now, cancelling any run that is
           * still waiting on its timer. The run finalizes the current brush stroke,
           * crops and scales the drawing, classifies it and pushes the result into
           * the chart.
           */
          function fireOffInference() {
            clearTimeout(timeoutId);
            isTimeOutSet = true;
            timeoutId = setTimeout(async () => {
              try {
                // Commit the in-progress stroke so it is part of the image being read.
                fabricCanvas.freeDrawingBrush._finalizeAndAddPath();
                const data = cropScaleGetImageData(mainContext, cropContext, scaledContext);
                const output = await mnist.inference(data);
                chart.data.datasets[0].data = output;
                chart.update();
              } catch (err) {
                // Report the failure instead of leaving an unhandled rejection.
                console.error("MNIST inference failed:", err);
              } finally {
                // Always release the guard; otherwise a single failure would block
                // live inference while drawing until the next mouse:up.
                isTimeOutSet = false;
              }
            }, 50);
          }

          fabricCanvas.on("mouse:down", () => {
            isDrawing = true;
          });

          // Always classify the finished stroke.
          fabricCanvas.on("mouse:up", () => {
            isDrawing = false;
            fireOffInference();
          });

          // While drawing, only schedule a run when none is already scheduled or running.
          fabricCanvas.on("mouse:move", () => {
            if (isDrawing && !isTimeOutSet) {
              fireOffInference();
            }
          });
        })
        .catch((err) => {
          console.error("Failed to initialize the MNIST WebAssembly module:", err);
        });
    </script>
  </body>
</html>
